import torch
import numpy as np
import matplotlib.pyplot as plt

import random
from collections import defaultdict
from typing import OrderedDict
from functools import partial
import argparse
from numpy.linalg import norm

import sys

# import packages from parent directory
sys.path.append('..')
from optimizer.DAS2C import TiAda, TiAda_Adam, TiAda_wo_max

from tensorboard_logger import Logger

# Command-line arguments.
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=8, help='random seed')
# n_iter counts gradient calls; each loop iteration consumes two (one for x, one for y).
parser.add_argument('--n_iter', type=int, default=3000, help='number of gradient calls')
parser.add_argument('--lr_y', type=float, default=0.01, help='learning rate of y')
parser.add_argument('--r', type=float, default=1, help='stepsize of y / stepsize of x')
parser.add_argument('--init_x', type=float, default=0, help='init value of x')
parser.add_argument('--init_y', type=float, default=0, help='init value of y')
# The noise value multiplies a standard-normal draw, so it is a standard
# deviation, not a variance.
parser.add_argument('--grad_noise_y', type=float, default=0.00,
                    help='standard deviation of the Gaussian noise added to the y-gradient')
parser.add_argument('--grad_noise_x', type=float, default=0.00,
                    help='standard deviation of the Gaussian noise added to the x-gradient')

parser.add_argument('--func', type=str, default='quadratic', help='function name')
# NOTE(review): args.L_ave does not appear to be read anywhere in this script;
# the run name uses a module-level L_ave instead. Kept for CLI compatibility.
parser.add_argument('--L_ave', type=float, default=1, help='parameter for the test function')

parser.add_argument('--optim', type=str, default='TiAda', help='optimizer')
# alpha is the exponent on the x accumulator, beta the exponent on the y accumulator.
parser.add_argument('--alpha', type=float, default=0.6, help='parameter for TiAda')
parser.add_argument('--beta', type=float, default=0.4, help='parameter for TiAda')
parser.add_argument('--tracking', type=int, default=0,
                    help='set to 1 to also mix the adaptive scale accumulators with the gossip matrix')

args = parser.parse_args()

# Set precision to 64
torch.set_default_dtype(torch.float64)

# Per-node constants of the quadratic test function. L has shape (n, 1), so
# broadcasting against the (n, dim) iterates gives node k the constant L[k].
L = torch.tensor([[1.0], [2.0]])
n = L.shape[0]  # number of nodes: one row of L per node

# Average of the per-node constants (used below to tag the log-file name).
# It is derived from L so the two cannot drift apart when L is edited.
L_ave = L.mean().item()

# Mixing (gossip) matrix, one row/column per node; these values are symmetric
# and doubly stochastic.
matrix = np.array([[2/3, 1/3],
                   [1/3, 2/3]])
if matrix.shape != (n, n):
    raise ValueError(f"mixing matrix must be {n}x{n}, got {matrix.shape}")

W = torch.tensor(matrix)

# Different functions
functions = OrderedDict()


functions["quadratic"] = {
    "func":
        lambda x, y: -9 / 20 * (y ** 2) + L * x * y - (L ** 2 / 2) * (x ** 2) - x + 3/5 * y,
    "grad_x":
        lambda x, y: L*y -(L**2)*x - 1,
    "grad_y":
        lambda x, y: -9/10 * y + L*x + 3/5,
}

# Seed every RNG in play so a run is reproducible for a given --seed.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

n_iter, ratio = args.n_iter, args.r

print(f"Function: {args.func}")
print(f"Optimizer: {args.optim}")

# Pick the objective and its two partial derivatives from the registry.
_selected = functions[args.func]
fun, grad_x, grad_y = _selected["func"], _selected["grad_x"], _selected["grad_y"]

# Only "McCormick" is treated as two-dimensional; every other choice is scalar per node.
dim = 2 if args.func == "McCormick" else 1

# Tensorboard: encode the run configuration in the log path.
_name_parts = [f"./logs/{args.optim}_"]
_name_parts.append(f"L{L_ave}" if args.func == 'quadratic' else f"{args.func}")
_name_parts.append(f"_r_{ratio}_lry_{args.lr_y}")
if 'TiAda' in args.optim:
    _name_parts.append(f"_a_{args.alpha}_b_{args.beta}")
if args.grad_noise_x != 0:
    _name_parts.append(f"_noisex_{args.grad_noise_x}")
if args.grad_noise_y != 0:
    _name_parts.append(f"_noisey_{args.grad_noise_y}")
if args.tracking == 1:
    _name_parts.append("_tracking")
filename = "".join(_name_parts)

logger = Logger(filename)

# Step sizes: lr_y comes from the CLI and lr_x is lr_y divided by --r.
lr_y = args.lr_y
lr_x = lr_y / ratio

# Per-node adaptive accumulators, shape (n, 1), starting at 1.
scale_y = np.ones((n, 1))
scale_x = np.ones((n, 1))


def _initial_point(value):
    """Return an (n, dim) start point: standard normal if value is None, else filled with value."""
    if value is None:
        return torch.randn(n, dim)
    return torch.Tensor(np.full((n, dim), value))


# x is drawn before y so any random draws happen in a fixed order.
init_x = _initial_point(args.init_x)
init_y = _initial_point(args.init_y)
if args.func != 'bilinear':
    print(f"init x: {init_x}, init y: {init_y}")

# Leaf parameters holding every node's iterate (one row per node).
x = torch.nn.parameter.Parameter(init_x.clone())
y = torch.nn.parameter.Parameter(init_y.clone())

if args.optim == 'TiAda':
    # optim_y is built first because optim_x takes it as opponent_optim.
    # NOTE(review): neither optimizer is stepped by the visible loop, which
    # applies the update by hand.
    optim_y = TiAda([y], lr=lr_y, alpha=args.beta)
    optim_x = TiAda([x], opponent_optim=optim_y, lr=lr_x, alpha=args.alpha)

i = 0

# Record the initial state at step 0. The logged quantities are:
#   - x_ave, y_ave: node-averaged iterates;
#   - x_grad_ave: squared mean over nodes of grad_x evaluated at (x_ave, y_ave).
x_ave = torch.mean(x)
y_ave = torch.mean(y)
x_grad_ave = grad_x(x_ave, y_ave)
x_grad_ave_norm = (torch.mean(x_grad_ave).item()) ** 2

logger.scalar_summary('x_grad_ave', step=i, value=x_grad_ave_norm)
logger.scalar_summary('x_ave', step=i, value=x_ave.item())
logger.scalar_summary('y_ave', step=i, value=y_ave.item())

# Norm (divided by n) of each node's scale**-exponent factor relative to the
# factor built from the node-averaged scale, minus one: zero when all nodes
# share the same accumulator. Both are recorded; zeta_u was previously
# computed but dropped.
zeta_u = norm((np.mean(scale_y) ** args.beta) / (scale_y ** args.beta) - 1) / n
zeta_v = norm((np.mean(scale_x) ** args.alpha) / (scale_x ** args.alpha) - 1) / n
logger.scalar_summary('zeta_u', step=i, value=zeta_u)
logger.scalar_summary('zeta_v', step=i, value=zeta_v)


outer_loop_count = 0

# The loop below hand-codes the TiAda update. With any other --optim, `i`
# would never advance and the loop would spin forever, so fail fast instead.
if args.optim != 'TiAda':
    raise NotImplementedError(
        f"--optim {args.optim!r} is not implemented in this script; use --optim TiAda")

while i < n_iter:
    # Gradients come from the closed-form grad_x / grad_y, so no autograd graph
    # is needed; without no_grad one would keep growing across iterations.
    with torch.no_grad():
        l = fun(x, y)
        g_x = grad_x(x, y)
        g_y = grad_y(x, y)
        # Stochastic gradient: additive Gaussian noise whose std is grad_noise_*
        # (y is drawn before x so the random stream is unchanged).
        g_y = g_y + args.grad_noise_y * torch.randn(n, dim)
        g_x = g_x + args.grad_noise_x * torch.randn(n, dim)

        # Each iteration consumes one x-gradient and one y-gradient call.
        i += 2

        # Add each node's squared gradient norm, shape (n, 1), to its
        # accumulator. This used to add the squared value squared again (g**4).
        scale_y += (g_y.numpy() ** 2).sum(axis=1, keepdims=True)
        scale_x += (g_x.numpy() ** 2).sum(axis=1, keepdims=True)
        # The x accumulator is kept at least as large as y's (stored back).
        scale_x = np.maximum(scale_x, scale_y)

        # Local ascent on y and descent on x with per-node adaptive step sizes.
        y = y + lr_y * torch.tensor(scale_y ** (-args.beta)) * g_y
        x = x - lr_x * torch.tensor(scale_x ** (-args.alpha)) * g_x

        y = torch.matmul(W, y)  # y communication
        x = torch.matmul(W, x)  # x communication
        if args.tracking == 1:
            scale_y = np.dot(W, scale_y)  # V communication
            scale_x = np.dot(W, scale_x)  # U communication

        # Record the post-update state at step i. Step 0 is logged before the
        # loop, so nothing is duplicated and the final iterate is included.
        x_ave = torch.mean(x)
        y_ave = torch.mean(y)
        x_grad_ave = grad_x(x_ave, y_ave)
        x_grad_ave_norm = (torch.mean(x_grad_ave).item()) ** 2

        logger.scalar_summary('x_grad_ave', step=i, value=x_grad_ave_norm)
        logger.scalar_summary('x_ave', step=i, value=x_ave.item())
        logger.scalar_summary('y_ave', step=i, value=y_ave.item())

        zeta_u = norm((np.mean(scale_y) ** args.beta) / (scale_y ** args.beta) - 1) / n
        zeta_v = norm((np.mean(scale_x) ** args.alpha) / (scale_x ** args.alpha) - 1) / n
        logger.scalar_summary('zeta_u', step=i, value=zeta_u)
        logger.scalar_summary('zeta_v', step=i, value=zeta_v)

    # Stop on divergence; a NaN would slip past a plain `> 1e4` comparison.
    if not np.isfinite(x_grad_ave_norm) or x_grad_ave_norm > 1e4:
        break

# Report the final state. Previously this printed one step before the end, and
# never when n_iter was odd.
print('Scalar_x_y:\n x', scale_x, '\n y', scale_y)
print('Value_x_y:\n x', x, '\n y', y)